#!/usr/bin/python
# -*- coding: UTF-8 -*-
# author : bird.zhang@ximalaya.com

import base64
import hmac

from Crypto.Cipher import AES

from . import constant

# Valid AES key sizes in bytes: AES-128, AES-192 and AES-256.
key_len_16, key_len_24, key_len_32 = 16, 24, 32


def __convert_to_bytes(object_):
    """
    Convert ``object_`` to ``bytes``.

    Bytes-like objects (``bytes``, ``bytearray``, ``memoryview``, and
    subclasses) are copied into a plain ``bytes`` with their raw content.
    Any other value is converted with ``str()`` and encoded with
    ``constant.CHARSET_UTF8``.

    :param object_: value to convert
    :return: bytes
    """
    if isinstance(object_, (bytes, bytearray, memoryview)):
        # bytes(...) copies the raw content. Going through str() would
        # instead produce text like "bytearray(b'...')".
        return bytes(object_)
    return str(object_).encode(constant.CHARSET_UTF8)


def refactor_key(key):
    """
    Normalize ``key`` to a valid AES key length.

    AES keys must be exactly 16 (AES-128), 24 (AES-192) or 32 (AES-256)
    bytes long. The key is converted to bytes and right-padded with NUL
    bytes up to the smallest of those sizes that can hold it.

    - An empty key becomes 16 NUL bytes.
    - A key longer than 32 bytes is truncated to its first 32 bytes.

    :param key: key; converted to bytes by ``__convert_to_bytes``
    :return: normalized key, bytes of length 16, 24 or 32
    """
    raw = __convert_to_bytes(key)

    for size in (key_len_16, key_len_24, key_len_32):
        if len(raw) <= size:
            return raw.ljust(size, b'\0')

    # Longer than 32 bytes: keep only the first 32.
    return raw[:key_len_32]


def refactor_message(message):
    """
    Zero-pad ``message`` to a whole number of AES blocks.

    The message is converted to bytes by ``__convert_to_bytes``. It is then
    right-padded with NUL bytes to the next multiple of the AES block size
    (16 bytes). A message whose length is already a multiple of 16 is
    returned without padding; this includes the empty message.

    NOTE(review): zero padding is ambiguous for plaintexts that themselves
    end in NUL bytes. Stripping trailing NULs after decryption also removes
    them from the data.

    :param message: plaintext
    :return: bytes whose length is a multiple of 16
    """
    message_bytes = __convert_to_bytes(message)
    # -n % b is the number of bytes needed to reach the next multiple of b
    # (0 when n is already a multiple).
    padding = -len(message_bytes) % AES.block_size
    return message_bytes + b'\0' * padding


class Cryptor:
    """
    AES-CBC encryptor/decryptor with a fixed, key-derived IV.

    The key is normalized with ``refactor_key``, and its first 16 bytes are
    reused as the IV. Encryption is therefore deterministic: the same key
    and message always give the same ciphertext. ``check_auth`` relies on
    this.

    NOTE(review): a static IV derived from the key is not a secure CBC
    construction, because identical plaintext prefixes produce identical
    ciphertext prefixes. It is kept because existing ciphertexts and
    ``check_auth`` depend on deterministic output.
    """

    version = 1

    def __init__(self, key):
        self.key = refactor_key(key)
        # The first 16 bytes of the normalized key double as the CBC IV.
        self.iv = self.key[:key_len_16]
        self.mode = AES.MODE_CBC

    def _new_cipher(self):
        """
        Create a fresh cipher object for a single encrypt/decrypt call.

        A CBC cipher object carries chaining state, so a new one is built
        per call to make every operation start from ``self.iv``.
        """
        return AES.new(self.key, self.mode, self.iv)

    def encrypt(self, message):
        """
        Encrypt ``message`` and return the ciphertext as base64 text.

        :param message: plaintext; converted to bytes and NUL-padded to a
            multiple of 16 bytes by ``refactor_message``
        :return: base64-encoded ciphertext, as str
        """
        ciphertext = self._new_cipher().encrypt(refactor_message(message))
        # Raw ciphertext bytes are arbitrary binary data, which may not
        # print or store cleanly as text. Base64-encode them into a text
        # string.
        return str(base64.b64encode(ciphertext), constant.CHARSET_UTF8)

    def decrypt(self, ciphertext):
        """
        Decrypt base64 text produced by ``encrypt``.

        :param ciphertext: base64-encoded ciphertext
        :return: plaintext str, after stripping trailing NUL padding and
            decoding with ``constant.CHARSET_UTF8``
        """
        plain_text = self._new_cipher().decrypt(base64.b64decode(ciphertext))
        # Strip the NUL-byte padding added on encryption. Plaintext that
        # itself ended in NUL bytes loses them too; zero padding cannot
        # tell the two apart.
        return plain_text.rstrip(b'\0').decode(constant.CHARSET_UTF8)


def check_auth(message, ciphertext, key):
    """
    Check whether ``key`` is correct, using a known plaintext/ciphertext pair.

    ``message`` is encrypted with ``Cryptor(key)``, and the result is
    compared against ``ciphertext``. The key is considered correct only if
    they match.

    :param message: known plaintext; ``None`` or ``''`` is rejected
    :param ciphertext: expected base64 ciphertext (str); ``None``, ``''``
        or a non-str value is rejected
    :param key: key to verify
    :return: True if the key reproduces ``ciphertext``, otherwise False
    """
    if message is None or '' == message:
        return False

    if ciphertext is None or '' == ciphertext:
        return False

    # Cryptor.encrypt always returns str, so any other type can never match.
    if not isinstance(ciphertext, str):
        return False

    expected = Cryptor(key).encrypt(message)
    # Constant-time comparison, so response timing does not reveal how much
    # of a supplied ciphertext was correct. Both sides are encoded to bytes
    # because compare_digest rejects non-ASCII str. 'surrogatepass' keeps
    # the encoding total for any str input.
    return hmac.compare_digest(
        expected.encode('utf-8', 'surrogatepass'),
        ciphertext.encode('utf-8', 'surrogatepass'),
    )


class AuthFailedException(Exception):
    """
    Exception type for authentication failures.

    NOTE(review): nothing in this module raises it. Presumably callers raise
    it when ``check_auth`` returns False; confirm.
    """

    def __init__(self, *args, **kwargs):
        # Keyword arguments are accepted and discarded. The positional
        # arguments are stored in ``self.args``, as BaseException does.
        super().__init__(*args)
